"""Phase 3c: Gravity (lightweight) + time variation + quadrant + oil profiles."""

import sys, os
import pandas as pd
import numpy as np
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# NOTE(review): OUT is not referenced anywhere below; this script only
# prints its results and writes no tables.
OUT = '/mnt/c/demographics_capital_flows/eu_demographics/output/tables'

# Country-year panel with resource rents, restricted to years <= 2024.
panel = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel_with_resources.csv')
panel = panel[panel['year'] <= 2024].copy()
# Commodity exporter = resource rents >= 10. A missing value for rents stays NaN
# instead of being silently coded 0.0 (NaN >= 10 evaluates to False).
panel['commodity_exporter'] = (
    (panel['resource_rents_gdp'] >= 10).astype(float)
    .where(panel['resource_rents_gdp'].notna())
)
# Interaction of the first demographic factor with resource rents.
panel['Z1_x_resource'] = panel['Z_1'] * panel['resource_rents_gdp']

# Control regressors shared by every specification, and the baseline
# regressor list (demographic factors Z_1..Z_3 plus controls).
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
base_vars = ['Z_1', 'Z_2', 'Z_3'] + controls

def run_model(name, dep, indep, data, entity_col='iso3', quiet=False, min_obs=50):
    """Fit a PanelGLS regression of *dep* on *indep* and optionally print it.

    Rows with a missing value in *entity_col*, ``'year'``, *dep* or any of
    *indep* are dropped before fitting.

    Parameters
    ----------
    name : str
        Label used in printed output and stored under ``'model'``.
    dep : str
        Dependent-variable column in *data*.
    indep : list of str
        Regressor columns in *data*; printed coefficients follow this order.
    data : pandas.DataFrame
        Must contain *entity_col*, ``'year'``, *dep* and every column of *indep*.
    entity_col : str
        Column identifying the panel entity (passed to ``PanelGLS.fit``).
    quiet : bool
        If True, skip the coefficient table. The insufficient-observations
        notice is printed regardless.
    min_obs : int
        Minimum number of complete-case rows required to fit (default 50).

    Returns
    -------
    dict or None
        Keys ``model``, ``n_obs``, ``r_squared``, ``vars``, ``beta``, ``se``
        and ``pvalues``, or None when fewer than *min_obs* rows remain.
    """
    est = data[[entity_col, 'year', dep] + indep].dropna()
    if len(est) < min_obs:
        print(f"  {name}: insufficient obs ({len(est)})")
        return None
    m = PanelGLS()
    m.fit(est[dep].values, est[indep].values, est[entity_col].values, est['year'].values)
    if not quiet:
        print(f"\n  {name}: N={m.n_obs}, R²={m.r_squared:.4f}")
        for v, coef, se, pval in zip(indep, m.beta, m.se, m.pvalues):
            # Significance stars come from the module-level st() helper.
            print(f"    {v:35s} {coef:10.4f} ({se:.4f}){st(pval)}")
    return {'model': name, 'n_obs': m.n_obs, 'r_squared': m.r_squared,
            'vars': indep, 'beta': m.beta, 'se': m.se, 'pvalues': m.pvalues}

def st(p):
    """Return significance stars for p-value *p*.

    ``'***'`` if p < 0.01, ``'**'`` if p < 0.05, ``'*'`` if p < 0.1,
    otherwise ``''`` (including for NaN).
    """
    for cutoff, stars in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return stars
    return ''

# ═══════════════════════════════════════════════════════════════════════
# SECTION 7: BILATERAL — use country-level aggregates instead
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("SECTION 7: Bilateral Flows — Country-Level Aggregates")
print("=" * 70)

bilateral = pd.read_csv('/mnt/c/demographics_capital_flows/gravity_bilateral/data/processed/bilateral_panel.csv')

# Country-year regressors merged onto every aggregate. They do not depend on
# the flow variable, so they are built once outside the loop.
country_vars = panel[['iso3', 'year', 'Z_1', 'Z_2', 'Z_3'] + controls
                     + ['resource_rents_gdp']].drop_duplicates()
interaction_spec = ['Z_1', 'Z_2', 'Z_3'] + controls + ['resource_rents_gdp', 'Z1_x_resource']

for fv in ['portfolio_total', 'portfolio_debt', 'fdi_outward']:
    if fv not in bilateral.columns:
        continue
    # Drop unreported pairs before aggregating. A groupby-sum of an all-NaN
    # group is 0, which after clip(lower=1) and log would enter the regression
    # as a log-flow of 0 instead of being treated as missing.
    flows = bilateral.dropna(subset=[fv])

    # Outward flows aggregated by origin
    origin_agg = flows.groupby(['iso_o', 'year']).agg(
        total_outward=(fv, 'sum'),
        n_destinations=(fv, 'count')
    ).reset_index().rename(columns={'iso_o': 'iso3'})

    # Totals below 1 are floored at 1 so the log is defined (log-flow 0).
    origin_agg[f'log_{fv}_out'] = np.log(origin_agg['total_outward'].clip(lower=1))

    # Merge with demographics and resource rents
    origin_agg = origin_agg.merge(country_vars, on=['iso3', 'year'], how='left')

    origin_rents = origin_agg['resource_rents_gdp']
    origin_agg['Z1_x_resource'] = origin_agg['Z_1'] * origin_rents
    # Missing rents stay NaN rather than being coded as non-commodity (0.0).
    origin_agg['commodity_exporter'] = (
        (origin_rents >= 10).astype(float).where(origin_rents.notna())
    )
    origin_agg['Z1_x_commodity'] = origin_agg['Z_1'] * origin_agg['commodity_exporter']

    dep = f'log_{fv}_out'
    print(f"\n--- {fv} (outward, country-level) ---")

    run_model(f'{fv}: Z₁ baseline', dep,
              ['Z_1', 'Z_2', 'Z_3'] + controls, origin_agg)

    run_model(f'{fv}: + Z₁ × Resource', dep, interaction_spec, origin_agg)

    # Inward flows aggregated by destination
    dest_agg = flows.groupby(['iso_d', 'year']).agg(
        total_inward=(fv, 'sum'),
    ).reset_index().rename(columns={'iso_d': 'iso3'})

    dest_agg[f'log_{fv}_in'] = np.log(dest_agg['total_inward'].clip(lower=1))

    dest_agg = dest_agg.merge(country_vars, on=['iso3', 'year'], how='left')

    dest_agg['Z1_x_resource'] = dest_agg['Z_1'] * dest_agg['resource_rents_gdp']

    dep_in = f'log_{fv}_in'
    run_model(f'{fv} inward: + Z₁ × Resource', dep_in, interaction_spec, dest_agg)

# ═══════════════════════════════════════════════════════════════════════
# SECTION 8: TIME VARIATION
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 8: Time Variation in Commodity-Demographic Interaction")
print("=" * 70)

# Sub-period windows (label, first year, last year; both bounds inclusive)
# for the CA regression with the Z₁ × resource-rents interaction.
sub_periods = (
    ('1980-1999', 1980, 1999),
    ('2000-2009', 2000, 2009),
    ('2010-2021', 2010, 2021),
)

for label, first_year, last_year in sub_periods:
    in_window = (panel['year'] >= first_year) & (panel['year'] <= last_year)
    res = run_model(f'Period {label}', 'ca_gdp',
                    base_vars + ['resource_rents_gdp', 'Z1_x_resource'],
                    panel[in_window], quiet=True)
    if res is None:
        continue
    coefs, pvals = res['beta'], res['pvalues']
    pos = {name: k for k, name in enumerate(res['vars'])}
    iz, ix, ir = pos['Z_1'], pos['Z1_x_resource'], pos['resource_rents_gdp']
    print(f"  {label}: N={res['n_obs']}, R²={res['r_squared']:.4f}, "
          f"Z₁={coefs[iz]:.3f}{st(pvals[iz])}, "
          f"Z₁×Res={coefs[ix]:.4f}{st(pvals[ix])}, "
          f"Rents={coefs[ir]:.3f}{st(pvals[ir])}")

# ═══════════════════════════════════════════════════════════════════════
# SECTION 9: QUADRANT REGRESSIONS
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 9: Country Classification — Demographic-Commodity Matrix")
print("=" * 70)

# Classify countries on their 2015-2021 averages: "old" when mean Z_1 > 0,
# "commodity" when mean resource rents >= 10. Each quadrant's CA regression
# then uses the full panel history of its member countries.
recent = panel[(panel['year'] >= 2015) & (panel['year'] <= 2021)]
cc = recent.groupby('iso3').agg(
    mean_Z1=('Z_1', 'mean'), mean_rents=('resource_rents_gdp', 'mean'),
    mean_ca=('ca_gdp', 'mean')
).dropna(subset=['mean_Z1', 'mean_rents'])

cc['old'] = cc['mean_Z1'] > 0
cc['commodity'] = cc['mean_rents'] >= 10

# (old, commodity) -> quadrant label
quadrants = {
    (True, True): 'Old Commodity',
    (True, False): 'Old Non-Commodity',
    (False, True): 'Young Commodity',
    (False, False): 'Young Non-Commodity'
}

# Plain lookup over (old, commodity) pairs. Unlike a row-wise
# DataFrame.apply, this also works when cc has no rows.
cc['quadrant'] = [quadrants[key] for key in zip(cc['old'], cc['commodity'])]

for q in quadrants.values():
    members = cc[cc['quadrant'] == q]
    countries = members.index.tolist()
    sub_panel = panel[panel['iso3'].isin(countries)]

    r = run_model(q, 'ca_gdp', base_vars, sub_panel, quiet=True)
    if r is None:
        continue
    z1_i = r['vars'].index('Z_1')
    print(f"\n  {q:25s}: {len(countries):3d} countries, N={r['n_obs']}, "
          f"Z₁={r['beta'][z1_i]:.3f}{st(r['pvalues'][z1_i])}, R²={r['r_squared']:.4f}")
    print(f"    Mean CA: {members['mean_ca'].mean():.2f}, Mean Rents: {members['mean_rents'].mean():.1f}%")
    if len(countries) <= 20:
        print(f"    Countries: {', '.join(sorted(countries))}")
    else:
        print(f"    Example: {', '.join(sorted(countries)[:10])}...")

# ═══════════════════════════════════════════════════════════════════════
# SECTION 10: OIL EXPORTER PROFILES
# ═══════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("SECTION 10: Oil Exporter Profiles")
print("=" * 70)

oil_countries = ['SAU', 'KWT', 'ARE', 'QAT', 'OMN', 'BHR',
                 'RUS', 'NOR', 'IRN', 'IRQ', 'DZA', 'LBY',
                 'AGO', 'NGA', 'KAZ', 'AZE', 'VEN']

# 2018-2021 averages per oil exporter; countries without any rents data are dropped.
snap = panel[(panel['year'] >= 2018) & (panel['year'] <= 2021)]
oil_snap = snap[snap['iso3'].isin(oil_countries)].groupby('iso3').agg(
    Z1=('Z_1', 'mean'), OADR=('old_dep', 'mean'),
    CA=('ca_gdp', 'mean'), NFA=('nfa_gdp', 'mean'),
    Rents=('resource_rents_gdp', 'mean'), Fiscal=('fiscal_bal_gdp', 'mean')
).dropna(subset=['Rents'])

# Attach projected Z_1 / old-age dependency for 2030 and 2050.
full_proj = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
for yr in [2030, 2050]:
    proj = (full_proj.loc[full_proj['year'] == yr, ['iso3', 'Z_1', 'old_dep']]
            .rename(columns={'Z_1': f'Z1_{yr}', 'old_dep': f'OADR_{yr}'})
            .set_index('iso3'))
    # Left join: an exporter with no projection for this year is kept and
    # shown as nan instead of being silently dropped from the table.
    oil_snap = oil_snap.join(proj, how='left')

oil_snap = oil_snap.sort_values('Rents', ascending=False)

# 'Country' is 7 characters wide; header and rows share a 7-wide first column.
print(f"\n{'Country':7s} {'Rents%':>7s} {'Z₁':>7s} {'Z₁_30':>7s} {'Z₁_50':>7s} "
      f"{'OADR':>6s} {'OADR_50':>8s} {'CA':>6s} {'NFA':>6s} {'Fiscal':>7s}")
print("-" * 78)
for iso, row in oil_snap.iterrows():
    print(f"{iso:7s} {row['Rents']:7.1f} {row['Z1']:7.3f} {row['Z1_2030']:7.3f} "
          f"{row['Z1_2050']:7.3f} {row['OADR']*100:5.1f}% "
          f"{row['OADR_2050']*100:7.1f}% {row['CA']:6.1f} "
          f"{row['NFA']:6.1f} {row['Fiscal']:7.1f}")

print("\n✓ Phase 3c complete.")
